Decision Trees

We have described two types of machine learning algorithms: linear approaches (generalized linear models such as logistic regression, and discriminant analysis such as LDA and QDA) and a smoothing approach (k-nearest neighbors). The linear approaches are limited in that the partition of the prediction space must be linear (or, in the case of QDA, quadratic). A limitation of the smoothing approach is that, with a large number of predictors, we run into the curse of dimensionality.

The Curse of Dimensionality

A useful way of understanding the curse of dimensionality is by considering how large we have to make a neighborhood/window to include a given percentage of the data. For example, suppose we have one continuous predictor with equally spaced points in the [0,1] interval and we want to create windows that include 1/10-th of the data. Then it’s easy to see that our windows have to be of size 0.1:

rafalib::mypar()
x <- seq(0, 1, len=100)
y <- rep(1, 100)
plot(x, y, xlab = "", ylab = "", cex = 0.25, yaxt = "n", xaxt = "n", type = "n")
lines(x[c(15,35)], y[c(15,35)], col = "blue", lwd = 3)
points(x,y, cex = 0.25)
points(x[25],y[25], col = "blue", cex = 0.5, pch = 4)
text(x[c(15,35)], y[c(15,35)], c("[","]"))

Now, for two predictors, if we decide to keep the neighborhood just as small, 10% for each dimension, we include only 1 point:

or, if we want to include 10% of the data, we need to increase the window size so that each side covers \(\sqrt{0.10} \approx 0.316\) of the interval:

To include 10% of the data in a case with \(p\) features we need an interval for each predictor that covers \(0.10^{1/p}\) of the total. This proportion gets close to 1 (including all the data and no longer smoothing) quickly:
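We can compute this proportion directly. A short base R sketch of how quickly \(0.10^{1/p}\) approaches 1 as the number of predictors \(p\) grows:

```r
# Side of the window needed in each dimension to capture 10% of the data,
# as a function of the number of predictors p
p <- 1:100
side <- 0.1^(1/p)

# With p = 1 we need 10% of the axis; with p = 10 we already need ~79%
round(side[c(1, 2, 10)], 3)
## [1] 0.100 0.316 0.794

plot(p, side, type = "l",
     xlab = "Number of predictors (p)",
     ylab = "Proportion of each axis needed to cover 10% of data")
abline(h = 1, lty = 2)
```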

Here we look at a set of elegant and versatile methods that adapt to higher dimensions and also allow the regions to take more complex shapes, while still producing models that are interpretable. These are very popular, well-known and well-studied methods. We will concentrate on regression and decision trees and their extension to Random Forests.

Motivation

Consider the olives dataset below. We show two measured predictors, linoleic (percent linoleic acid of sample) and eicosenoic (percent eicosenoic acid of sample). Suppose we wanted to predict the olive’s region using these two predictors.

library(tidyverse)

olives <- read.csv("https://raw.githubusercontent.com/datasciencelabs/data/master/olive.csv", as.is = TRUE) %>% as_tibble()
names(olives)[1] <- "province"
region_names <- c("Southern Italy", "Sardinia", "Northern Italy")
olives <- olives %>% mutate(Region = factor(region_names[Region]))

p <- olives %>% ggplot(aes(eicosenoic, linoleic, fill = Region)) +
     geom_point(pch = 21)
p

Note that we can describe a classification algorithm that would work pretty much perfectly:

p <- p + geom_vline(xintercept = 6.5) + 
         geom_segment(x = -2, y = 1053.5, xend = 6.5, yend = 1053.5)
p

The prediction algorithm inferred from the figure above is what we call a decision tree: if eicosenoic is larger than 6.5, predict Southern Italy; if not, then predict Sardinia if linoleic is larger than 1,054 and Northern Italy otherwise. We can draw this decision tree like this:
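This rule is simply nested if/else logic. Here is a hand-coded sketch of that tree (the function name is ours; the cutoffs 6.5 and 1,054 are read off the figure):

```r
# Hand-coded version of the decision tree inferred from the olives figure
predict_region <- function(eicosenoic, linoleic) {
  if (eicosenoic > 6.5) {
    "Southern Italy"
  } else if (linoleic > 1054) {
    "Sardinia"
  } else {
    "Northern Italy"
  }
}

predict_region(eicosenoic = 20, linoleic = 1000)
## [1] "Southern Italy"
predict_region(eicosenoic = 2, linoleic = 1200)
## [1] "Sardinia"
```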

Decision trees like this are often used in practice. For example, to decide if a person is at risk of having a heart attack, doctors use the following:

The general idea of the methods we are describing is to define an algorithm that uses data to create these trees. Regression and decision trees operate by predicting an outcome variable \(Y\) by partitioning the feature (predictor) space.

Regression Trees

Let’s start with the case of a continuous outcome. The general idea here is to build a decision tree, and at each of its terminal nodes (leaves) we will have a different prediction \(\hat{Y}\) for the outcome \(Y\).

The regression tree model then:

  1. Partitions space into \(J\) non-overlapping regions, \(R_1, R_2, \ldots, R_J\).
  2. For every observation that falls within region \(R_j\), predict the response as the mean of responses for training observations in \(R_j\).

The important observation is that Regression Trees create partitions recursively.

For example, consider finding a good predictor \(j\) to partition space along its axis. A recursive algorithm would look like this:

Find predictor \(j\) and value \(s\) that minimize RSS:

\[ \sum_{i:\, x_i \in R_1(j,s)} (y_i - \hat{y}_{R_1})^2 + \sum_{i:\, x_i \in R_2(j,s)} (y_i - \hat{y}_{R_2})^2 \]

Where \(R_1\) and \(R_2\) are regions resulting from splitting observations on predictor \(j\) and value \(s\):

\[ R_1(j,s) = \{X|X_j < s\} \text{ and } R_2(j,s) = \{X|X_j \geq s\} \]

This is then applied recursively to regions \(R_1\) and \(R_2\). Within each region a prediction is made using \(\hat{y}_{R_j}\) which is the mean of the response \(Y\) of observations in \(R_j\).
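To make this concrete, here is a minimal sketch (base R, synthetic data; the helper name best_split is ours) of one recursive step: scanning a single predictor for the value \(s\) that minimizes the RSS above:

```r
# Find the cutoff s on one predictor x that minimizes
# RSS = sum over R1 of (y - mean(y in R1))^2 + sum over R2 of (y - mean(y in R2))^2
best_split <- function(x, y) {
  cutoffs <- sort(unique(x))
  rss <- sapply(cutoffs, function(s) {
    left  <- y[x < s]
    right <- y[x >= s]
    if (length(left) == 0 || length(right) == 0) return(Inf)
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(s = cutoffs[which.min(rss)], rss = min(rss))
}

# A step function with a jump at x = 0.5 is recovered by the split
set.seed(1)
x <- runif(1000)
y <- ifelse(x < 0.5, 1, 3) + rnorm(1000, sd = 0.1)
split <- best_split(x, y)
split$s  # close to the true change point at 0.5
```

A full regression tree applies this search over all predictors, then recurses on each resulting region.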

Let’s take a look at what this algorithm does on the Boston Housing dataset. This dataset contains information collected by the U.S. Census Service concerning housing in the area of Boston, MA in the 1970s; it was obtained from the StatLib archive. Among its 14 variables are medv (median value of owner-occupied homes in $1000s), lstat (% of individuals with lower socioeconomic status), rm (average number of rooms per dwelling), and dis (weighted distances to five Boston employment centres).

The dataset is small in size with only 506 cases, but we’ll use it for educational purposes.

We can use the tree function from the tree package to fit a regression tree and plot it.

library(tree)
library(MASS)

set.seed(1)

# Randomly sample half of the data for training
train = sample(1:nrow(Boston), nrow(Boston)/2)

# Fit a regression tree using all of the available predictors
fit = tree(medv ~ ., Boston, subset = train)  

# Print a summary of the tree
summary(fit)
## 
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm"    "lstat" "crim"  "age"  
## Number of terminal nodes:  7 
## Residual mean deviance:  10.38 = 2555 / 246 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -10.1800  -1.7770  -0.1775   0.0000   1.9230  16.5800
# Use tree for prediction
preds <- predict(fit, newdata = Boston[-train,])
test = Boston[-train, "medv"]
plot(fit, type = "uniform")
text(fit, cex = 1)

The tree suggests that houses with a higher number of rooms correspond to more expensive houses and predicts a median house price of $45,380 for a house with more than 8 rooms.

The idea behind the regression tree is that the outcome \(Y\) is estimated (or predicted) to be its mean within each of the data partitions. Think of it as a conditional mean of \(Y\), where the conditioning is given by this partitioning of the predictor space.

We can also use the predictions to calculate the MSE (mean squared error) and compare models.

plot(preds, test)
abline(0,1)

mean((preds-test)^2)
## [1] 35.28688

The test set MSE associated with this tree is 35.29. The square root of the MSE is 5.94, indicating this model leads to test predictions that are within approximately $5,940 of the true median home value for the suburb.

The rpart package is an alternative method for fitting trees in R. It is much more feature-rich, fitting multiple cost complexities and performing cross-validation by default. It can also produce much nicer tree plots (via the rpart.plot package). With its default settings, it will often result in smaller trees than the tree package.

library(rpart)
set.seed(1)

# Fit a regression tree using all of the available predictors
fit_rpart = rpart(medv ~ ., Boston, subset = train)  

# Print a summary of the tree
summary(fit_rpart)
## Call:
## rpart(formula = medv ~ ., data = Boston, subset = train)
##   n= 253 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.55145333      0 1.0000000 1.0049875 0.12676397
## 2 0.17610053      1 0.4485467 0.5433114 0.06172420
## 3 0.05689532      2 0.2724461 0.3657167 0.05488317
## 4 0.04093613      3 0.2155508 0.3230569 0.04410254
## 5 0.03276814      4 0.1746147 0.2909722 0.04325331
## 6 0.01048773      5 0.1418465 0.2443818 0.04178581
## 7 0.01000000      6 0.1313588 0.2465690 0.04171043
## 
## Variable importance
##      rm   lstat    crim     dis     age     nox      zn   indus     tax ptratio 
##      34      23       9       7       6       5       4       4       4       4 
## 
## Node number 1: 253 observations,    complexity param=0.5514533
##   mean=21.78656, MSE=76.86907 
##   left son=2 (222 obs) right son=3 (31 obs)
##   Primary splits:
##       rm      < 6.9595   to the left,  improve=0.5514533, (0 missing)
##       lstat   < 8.13     to the right, improve=0.4710854, (0 missing)
##       ptratio < 19.65    to the right, improve=0.2687694, (0 missing)
##       indus   < 6.66     to the right, improve=0.2552622, (0 missing)
##       nox     < 0.5125   to the right, improve=0.2357242, (0 missing)
##   Surrogate splits:
##       lstat   < 4.6      to the right, agree=0.925, adj=0.387, (0 split)
##       indus   < 1.605    to the right, agree=0.893, adj=0.129, (0 split)
##       ptratio < 14.15    to the right, agree=0.893, adj=0.129, (0 split)
##       zn      < 86.25    to the left,  agree=0.889, adj=0.097, (0 split)
##       crim    < 0.01958  to the right, agree=0.885, adj=0.065, (0 split)
## 
## Node number 2: 222 observations,    complexity param=0.1761005
##   mean=19.3536, MSE=30.60492 
##   left son=4 (87 obs) right son=5 (135 obs)
##   Primary splits:
##       lstat < 14.405   to the right, improve=0.5040674, (0 missing)
##       crim  < 6.108565 to the right, improve=0.3586261, (0 missing)
##       dis   < 2.23995  to the left,  improve=0.3565564, (0 missing)
##       age   < 93.1     to the right, improve=0.3387831, (0 missing)
##       nox   < 0.5835   to the right, improve=0.3244859, (0 missing)
##   Surrogate splits:
##       age  < 88.1     to the right, agree=0.829, adj=0.563, (0 split)
##       dis  < 2.23995  to the left,  agree=0.815, adj=0.529, (0 split)
##       nox  < 0.5835   to the right, agree=0.779, adj=0.437, (0 split)
##       tax  < 431      to the right, agree=0.775, adj=0.425, (0 split)
##       crim < 5.24741  to the right, agree=0.770, adj=0.414, (0 split)
## 
## Node number 3: 31 observations,    complexity param=0.05689532
##   mean=39.20968, MSE=62.22539 
##   left son=6 (16 obs) right son=7 (15 obs)
##   Primary splits:
##       rm    < 7.553    to the left,  improve=0.5736135, (0 missing)
##       lstat < 4.52     to the right, improve=0.5430443, (0 missing)
##       dis   < 3.39285  to the right, improve=0.2565775, (0 missing)
##       crim  < 0.260035 to the left,  improve=0.2074717, (0 missing)
##       nox   < 0.4965   to the left,  improve=0.1890524, (0 missing)
##   Surrogate splits:
##       lstat < 4.52     to the right, agree=0.806, adj=0.600, (0 split)
##       crim  < 0.11276  to the left,  agree=0.742, adj=0.467, (0 split)
##       zn    < 27.5     to the right, agree=0.742, adj=0.467, (0 split)
##       dis   < 4.74095  to the right, agree=0.710, adj=0.400, (0 split)
##       nox   < 0.48     to the left,  agree=0.677, adj=0.333, (0 split)
## 
## Node number 4: 87 observations,    complexity param=0.03276814
##   mean=14.46092, MSE=17.85962 
##   left son=8 (26 obs) right son=9 (61 obs)
##   Primary splits:
##       crim  < 11.48635 to the right, improve=0.4101403, (0 missing)
##       lstat < 19.645   to the right, improve=0.3149683, (0 missing)
##       nox   < 0.6615   to the right, improve=0.2835297, (0 missing)
##       tax   < 551.5    to the right, improve=0.2799295, (0 missing)
##       dis   < 2.0037   to the left,  improve=0.2601665, (0 missing)
##   Surrogate splits:
##       age   < 99       to the right, agree=0.805, adj=0.346, (0 split)
##       dis   < 1.66345  to the left,  agree=0.793, adj=0.308, (0 split)
##       black < 221.785  to the left,  agree=0.782, adj=0.269, (0 split)
##       rm    < 5.3695   to the left,  agree=0.759, adj=0.192, (0 split)
##       lstat < 30.06    to the right, agree=0.747, adj=0.154, (0 split)
## 
## Node number 5: 135 observations,    complexity param=0.04093613
##   mean=22.50667, MSE=13.44981 
##   left son=10 (111 obs) right son=11 (24 obs)
##   Primary splits:
##       rm      < 6.543    to the left,  improve=0.4384591, (0 missing)
##       lstat   < 7.76     to the right, improve=0.3773263, (0 missing)
##       nox     < 0.5125   to the right, improve=0.1511574, (0 missing)
##       age     < 33.8     to the right, improve=0.1421214, (0 missing)
##       ptratio < 18.65    to the right, improve=0.1184029, (0 missing)
##   Surrogate splits:
##       lstat < 5.055    to the right, agree=0.874, adj=0.292, (0 split)
##       crim  < 0.02902  to the right, agree=0.830, adj=0.042, (0 split)
## 
## Node number 6: 16 observations
##   mean=33.425, MSE=31.59312 
## 
## Node number 7: 15 observations
##   mean=45.38, MSE=21.1336 
## 
## Node number 8: 26 observations
##   mean=10.31538, MSE=11.64284 
## 
## Node number 9: 61 observations,    complexity param=0.01048773
##   mean=16.22787, MSE=10.06234 
##   left son=18 (31 obs) right son=19 (30 obs)
##   Primary splits:
##       age     < 93.95    to the right, improve=0.3322959, (0 missing)
##       lstat   < 18.825   to the right, improve=0.2782521, (0 missing)
##       crim    < 0.711085 to the right, improve=0.2045230, (0 missing)
##       ptratio < 19.45    to the right, improve=0.1866394, (0 missing)
##       black   < 344.48   to the left,  improve=0.1796701, (0 missing)
##   Surrogate splits:
##       dis   < 2.2085   to the left,  agree=0.787, adj=0.567, (0 split)
##       indus < 16.01    to the right, agree=0.721, adj=0.433, (0 split)
##       nox   < 0.597    to the right, agree=0.721, adj=0.433, (0 split)
##       lstat < 17.995   to the right, agree=0.721, adj=0.433, (0 split)
##       crim  < 0.308165 to the right, agree=0.705, adj=0.400, (0 split)
## 
## Node number 10: 111 observations
##   mean=21.37748, MSE=6.875078 
## 
## Node number 11: 24 observations
##   mean=27.72917, MSE=10.68623 
## 
## Node number 18: 31 observations
##   mean=14.42903, MSE=5.294318 
## 
## Node number 19: 30 observations
##   mean=18.08667, MSE=8.190489
# Use tree for prediction
preds <- predict(fit_rpart, newdata = Boston[-train,])
test = Boston[-train, "medv"]

Let’s plot the resulting tree:

library(rpart.plot)
rpart.plot(fit_rpart, digits = 4)

We get the same predictions as we did using the tree function. We can also calculate the MSE and check that it is also the same:

mean((preds-test)^2)
## [1] 35.28688

Specifics of the regression tree algorithm

The recursive partitioning algorithm described above leads to a set of natural questions:

  1. When do we stop partitioning?

We stop when adding a partition does not reduce MSE, or, when a partition has too few training observations. Even then, trees built with this stopping criterion tend to overfit to the training data. To avoid this, a post-processing step called pruning is used to make the tree smaller.

  2. Why would a smaller tree tend to generalize better?

The cv.tree function uses cross-validation to determine a reasonable tree size (number of terminal nodes) for the given dataset. For this dataset it seems that a size of 7 works well, since the error, or “deviance,” reaches its minimum there:

set.seed(1)
cv_boston = cv.tree(fit)
plot(cv_boston$size, cv_boston$dev, type = 'b')

However, if we decide to prune a tree we can do so using the prune.tree() function:

prune_boston = prune.tree(fit, best = 5)
plot(prune_boston, type = "uniform")
text(prune_boston)

Let’s calculate the test set MSE for the pruned tree:

preds_prune <- predict(prune_boston, newdata = Boston[-train,])
test = Boston[-train, "medv"]
mean((preds_prune-test)^2)
## [1] 35.90102

The MSE of the pruned tree is greater than the original, so we should not prune this decision tree.

We can also prune a tree using the prune function from the rpart package. Just as we did with the tree function above, we can first plot the error for different tree sizes. The cross-validation error is smallest with 7 leaves (terminal nodes), but we can prune the tree to 5 terminal nodes by using a complexity parameter (cp) of 0.037.

set.seed(1)
fit_rpart = rpart(medv ~ ., Boston, subset = train)  

# Use tree for prediction
preds <- predict(fit_rpart, newdata = Boston[-train,])
test = Boston[-train, "medv"]

plotcp(fit_rpart)

min_cp = fit_rpart$cptable[which.min(fit_rpart$cptable[,"xerror"]),"CP"]
min_cp
## [1] 0.01048773
fit_pruned <- prune(fit_rpart, cp = 0.037)
rpart.plot(fit_pruned, digits = 4)

Classification (Decision) Trees

Classification trees, also called decision trees, are used in classification problems, where the outcome is categorical. The same partitioning principle applies, but now each region predicts the majority class of the training observations within it. The recursive partitioning algorithm we saw previously requires a score function for choosing the predictors (and values) to partition with. In classification we could use a naive approach that looks for partitions minimizing training error, but better-performing approaches use more sophisticated metrics. Here are two of the most popular (defined for leaf \(m\)):

  • Gini Index: \(\sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})\), or

  • Entropy: \(-\sum_{k=1}^K \hat{p}_{mk}\log(\hat{p}_{mk})\)

where \(\hat{p}_{mk}\) is the proportion of training observations in partition \(m\) labeled as class \(k\). Both of these are measures of impurity: they favor partitions whose observations predominantly share the same label.
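As a quick sketch of these two measures (base R; the function names are ours), note that both are 0 for a pure leaf and largest when the classes are evenly mixed:

```r
# p: vector of class proportions p_mk within a leaf (sums to 1)
gini    <- function(p) sum(p * (1 - p))
entropy <- function(p) -sum(ifelse(p > 0, p * log(p), 0))  # treat 0*log(0) as 0

gini(c(1, 0))         # pure leaf
## [1] 0
gini(c(0.5, 0.5))     # evenly mixed
## [1] 0.5
entropy(c(0.5, 0.5))  # log(2)
## [1] 0.6931472
```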

Let us look at how a classification tree performs on the digits example we examined before:

We can see the prediction here:

We can again prune the tree if we wish, but in this case the pruned tree does not differ much from the original.

pruned_fit <- prune.tree(fit)
plot(pruned_fit)

Here is what a pruned tree looks like:

Classification trees have certain advantages that make them very useful. They are highly interpretable, even more so than linear models; they are easy to visualize (if small enough); they arguably mimic human decision processes; and they do not require dummy variables for categorical predictors.

On the other hand, the greedy approach of recursive partitioning makes them a bit harder to train than linear regression. They may not always be the best-performing method, since they are not very flexible and are highly unstable to changes in the training data. Below we will learn about the bootstrap, which helps with this.

Bootstrap

Suppose the income distribution of your population is as follows:

hist(log10(income))

The population median is

m <- median(income)
m
## [1] 45500.41

Suppose we don’t have access to the entire population but want to estimate the median \(m\). We take a sample of 250 and estimate the population median \(m\) with the sample median \(M\):

set.seed(1)
N <- 250
X <- sample(income, N)
M <- median(X)
M
## [1] 46052.21

Can we construct a confidence interval? What is the distribution of \(M\)?

From a simulation we see that the distribution of \(M\) is approximately normal with the following expected value and standard error:

B <- 10^5
Ms <- replicate(B, {
  X <- sample(income, N)
  M <- median(X)
})
par(mfrow=c(1,2))
hist(Ms)
qqnorm(Ms)
qqline(Ms)

mean(Ms)
## [1] 45647.71
sd(Ms)
## [1] 3635.773

The problem here is that, as we have described before, in practice we do not have access to the distribution. In the past we have used the central limit theorem. But the CLT we studied applies to averages and here we are interested in the median.

The bootstrap permits us to approximate a simulation without access to the entire distribution. The general idea is relatively simple: we act as if the sample is the population and draw (with replacement) datasets of the same size. Then we compute the summary statistic, in this case the median, on each of these bootstrap samples.

There is theory telling us that the distribution of the statistic obtained with bootstrap samples approximates the distribution of our actual statistic. This is how we construct bootstrap samples and an approximate distribution:

B <- 10^5
M_stars <- replicate(B, {
  X_star <- sample(X, N, replace = TRUE)
  M_star <- median(X_star)
})

Now we can check how close it is to the actual distribution

qqplot(Ms, M_stars)
abline(0,1)  

We see it is not perfect but it provides a decent approximation:

quantile(Ms, c(0.05, 0.95))
##       5%      95% 
## 39901.85 51855.70
quantile(M_stars, c(0.05, 0.95))
##       5%      95% 
## 42227.29 54750.95

This is much better than what we get if we mindlessly use the CLT:

median(X) + 1.96 * sd(X)/sqrt(N) * c(-1,1)
## [1] 36880.41 55224.02

If we are willing to assume the distribution of \(M\) is approximately normal, we can use the bootstrap estimates of its expected value and standard error to construct a confidence interval:

mean(Ms) + 1.96*sd(Ms)*c(-1,1)
## [1] 38521.60 52773.83
mean(M_stars) + 1.96*sd(M_stars)*c(-1,1)
## [1] 38539.41 55109.06

Random Forests

Random Forests are a very popular approach that addresses the shortcomings of decision trees via resampling of the training data. The goal is to improve prediction performance and reduce instability by averaging multiple decision trees (a forest constructed with randomness). Two features help accomplish this.

The first feature is bagging (bootstrap aggregation). The general scheme:

  1. Build many decision trees \(T_1, T_2, \ldots, T_B\) from training set
  2. Given a new observation, let each \(T_j\) predict \(\hat{y}_j\)
  3. For regression: predict average \(\frac{1}{B} \sum_{j=1}^B \hat{y}_j\), for classification: predict with majority vote (most frequent class)

But how do we get many decision trees from a single training set?

For this we use the bootstrap. To create \(T_j, \, j=1,\ldots,B\) from a training set of size \(N\):

  1. Create a bootstrap training set by sampling \(N\) observations from training set with replacement
  2. Build a decision tree from the bootstrap training set
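The two steps above can be sketched by hand. This illustrative version uses rpart and synthetic data so that it runs on its own; in practice the randomForest package does this (and more) for us:

```r
library(rpart)

# Synthetic regression problem: a noisy sine curve
set.seed(1)
n <- 500
dat <- data.frame(x = runif(n))
dat$y <- sin(2 * pi * dat$x) + rnorm(n, sd = 0.3)

# 1. Create B bootstrap training sets; 2. build a tree from each
B <- 50
trees <- lapply(1:B, function(b) {
  boot <- dat[sample(n, n, replace = TRUE), ]
  rpart(y ~ x, data = boot)
})

# 3. For regression, average the B predictions
newdata <- data.frame(x = seq(0, 1, len = 100))
preds <- sapply(trees, predict, newdata = newdata)  # 100 x B matrix
bagged <- rowMeans(preds)
```

The averaged predictions are much smoother and more stable than those of any single tree in the ensemble.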

Let’s look at this using the Boston housing dataset. We fit a Random Forest by using the randomForest() function. Here, mtry = 13 indicates all 13 predictors should be considered for each split of the tree - in other words, bagging should be done.

library(randomForest)
library(MASS)
fit_bag <- randomForest(medv ~ ., data = Boston[train,], mtry = 13)
fit_bag
## 
## Call:
##  randomForest(formula = medv ~ ., data = Boston[train, ], mtry = 13) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 11.45798
##                     % Var explained: 85.09

How well does it perform on the test set?

preds_bag = predict(fit_bag, newdata = Boston[-train,])
plot(preds_bag, test)
abline(0,1)

mean((preds_bag - test)^2)
## [1] 23.48487

Much, much better than the single decision tree we used above. With bagging we get an MSE of 23.48, which (after taking the square root) translates to our predictions being within approximately $4,846 of the true median home value for the suburb. This is much lower than the $5,940 we got when using a single decision tree.

We can grow a random forest in the same way, except now we use a smaller value for mtry. The default for regression trees is \(p/3\) (\(p\) is the number of predictors) and the default for classification trees is \(\sqrt p\). Let’s try mtry = 6.

fit_rf <- randomForest(medv ~ ., data = Boston[train,], mtry = 6, importance = TRUE)
preds_rf = predict(fit_rf, newdata = Boston[-train,])
mean((preds_rf - test)^2)
## [1] 19.20602

It looks like the MSE is lower for this random forest, which considers a random subset of predictors at each split, than for bagging, which uses all of the predictors.

The second Random Forest feature is to use a random selection of features when deciding partitions: when building each tree \(T_j\), at each recursive partition we consider only a randomly selected subset of predictors to check for the best split. This reduces correlation between the trees in the forest, improving prediction accuracy.

Here is the random forest fit for the digits data:

detach("package:MASS", unload=TRUE)
library(randomForest)
fit <- randomForest(label ~ X_1 + X_2, data = digits_train)

We can control the “smoothness” of the random forest estimate in several ways. One way is to limit the size of each node. We can require the number of points per node to be larger:

fit <- randomForest(as.factor(label) ~ X_1  +X_2,
                    nodesize = 250,
                    data = digits_train)

We can compare the results:

library(caret)
get_accuracy <- function(fit){
  pred <- predict(fit, newdata = digits_test, type = "class")
  confusionMatrix(table(pred = pred, true = digits_test$label))$overall[1]
}
fit <- tree(label ~ X_1 + X_2, data = digits_train)
get_accuracy(fit)
##  Accuracy 
## 0.8036381
fit <- randomForest(label ~ X_1 + X_2, data = digits_train)
get_accuracy(fit)
##  Accuracy 
## 0.8008396
fit <- randomForest(label ~ X_1 + X_2,
                    nodesize = 250,
                    data = digits_train)
get_accuracy(fit)
##  Accuracy 
## 0.8211287

A disadvantage of random forests is that we lose interpretability. However, because each tree is built from a bootstrap sample, we can measure variable importance from the random forest. The importance of a variable reflects how predictive that variable is.

Let’s see this using all the digits data:

url <- "https://raw.githubusercontent.com/datasciencelabs/data/master/hand-written-digits-train.csv"
digits <- read_csv(url)

digits <- mutate(digits, label = as.factor(label))

inTrain   <- createDataPartition(y = digits$label, p = 0.9, times = 1, list = FALSE)
train_set <- slice(digits, inTrain)
test_set  <- slice(digits, -inTrain)

fit <- randomForest(label~., ntree = 100, data = train_set)

How well does it do?

pred <- predict(fit, newdata = test_set, type = "class")
confusionMatrix(table(pred = pred, true = test_set$label))
## Confusion Matrix and Statistics
## 
##     true
## pred   0   1   2   3   4   5   6   7   8   9
##    0 409   0   3   0   0   2   3   0   1   1
##    1   0 465   0   1   1   1   1   1   3   1
##    2   0   0 410   5   1   1   0   7   3   0
##    3   0   0   1 410   1   6   0   1   3   6
##    4   0   1   0   0 393   3   1   5   2   6
##    5   0   0   0   4   0 359   1   1   0   1
##    6   2   1   0   1   3   3 405   0   1   1
##    7   0   1   3   5   1   0   0 421   0   4
##    8   2   0   0   7   0   3   2   2 390   1
##    9   0   0   0   2   7   1   0   2   3 397
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9673          
##                  95% CI : (0.9615, 0.9725)
##     No Information Rate : 0.1115          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9637          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity           0.99031   0.9936  0.98321  0.94253  0.96560  0.94723
## Specificity           0.99736   0.9976  0.99550  0.99521  0.99525  0.99817
## Pos Pred Value        0.97613   0.9810  0.96019  0.95794  0.95620  0.98087
## Neg Pred Value        0.99894   0.9992  0.99814  0.99337  0.99630  0.99478
## Prevalence            0.09843   0.1115  0.09938  0.10367  0.09700  0.09032
## Detection Rate        0.09747   0.1108  0.09771  0.09771  0.09366  0.08556
## Detection Prevalence  0.09986   0.1130  0.10176  0.10200  0.09795  0.08723
## Balanced Accuracy     0.99384   0.9956  0.98936  0.96887  0.98043  0.97270
##                      Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity           0.98063   0.9568  0.96059  0.94976
## Specificity           0.99683   0.9963  0.99551  0.99603
## Pos Pred Value        0.97122   0.9678  0.95823  0.96359
## Neg Pred Value        0.99788   0.9949  0.99578  0.99445
## Prevalence            0.09843   0.1049  0.09676  0.09962
## Detection Rate        0.09652   0.1003  0.09295  0.09461
## Detection Prevalence  0.09938   0.1037  0.09700  0.09819
## Balanced Accuracy     0.98873   0.9765  0.97805  0.97290

Here is a table of variable importance for the random forest we just constructed.

library(knitr)
variable_importance <- importance(fit) 
tmp <- tibble(feature = rownames(variable_importance),
                  Gini = variable_importance[,1]) %>%
                  arrange(desc(Gini))
kable(tmp[1:10,])
feature Gini
pixel347 303.4591
pixel406 279.1607
pixel378 273.9041
pixel433 263.1719
pixel350 255.2681
pixel409 252.5079
pixel461 244.0507
pixel437 241.0175
pixel489 235.2015
pixel405 231.7118

We can see where the “important” features are:

expand.grid(Row = 1:28, Column = 1:28) %>%
            mutate(value = variable_importance[,1]) %>%
            ggplot(aes(Row, Column, fill = value)) +
            geom_raster() +
            scale_y_reverse() 

And a barplot of the same data showing only the most predictive features:

tmp %>% filter(Gini > 200) %>%
        ggplot(aes(x=reorder(feature, Gini), y=Gini)) +
        geom_bar(stat='identity') +
        coord_flip() + xlab("Feature") +
        theme(axis.text=element_text(size=8))

Tree-based methods summary

Tree-based methods are very interpretable prediction models for which some inferential tasks are possible (e.g., variable importance in random forests), though these are much more limited than for the linear models we saw previously. These methods are commonly used across many application domains, and Random Forests often achieve state-of-the-art performance.